{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fecb2fb5-b87f-4428-8175-e3a46fe77371",
   "metadata": {},
   "source": [
    "## Tutorial: Optimizing a Prompt\n",
    "\n",
    "![TextGrad](https://github.com/vinid/data/blob/master/logo_full.png?raw=true)\n",
    "\n",
    "An autograd engine -- for textual gradients!\n",
    "\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/zou-group/TextGrad/blob/main/examples/notebooks/Prompt-Optimization.ipynb)\n",
    "[![GitHub license](https://img.shields.io/badge/License-MIT-blue.svg)](https://lbesson.mit-license.org/)\n",
    "[![Arxiv](https://img.shields.io/badge/arXiv-2406.07496-B31B1B.svg)](https://arxiv.org/abs/2406.07496)\n",
    "[![Documentation Status](https://readthedocs.org/projects/textgrad/badge/?version=latest)](https://textgrad.readthedocs.io/en/latest/?badge=latest)\n",
    "[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/textgrad)](https://pypi.org/project/textgrad/)\n",
    "[![PyPI](https://img.shields.io/pypi/v/textgrad)](https://pypi.org/project/textgrad/)\n",
    "\n",
    "**Objectives:**\n",
    "\n",
    "* In this tutorial, we will run prompt optimization.\n",
    "\n",
    "**Requirements:**\n",
    "\n",
    "* You need to have an OpenAI API key to run this tutorial. This should be set as an environment variable as OPENAI_API_KEY.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "7add4547-4278-411b-a827-79be521851f1",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-06-11T19:30:34.029594610Z",
     "start_time": "2024-06-11T19:30:34.024175489Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "!pip install textgrad # you might need to restart the notebook after installing textgrad\n",
    "\n",
    "import argparse\n",
    "import concurrent\n",
    "from dotenv import load_dotenv\n",
    "from tqdm import tqdm\n",
    "import textgrad as tg\n",
    "from textgrad.tasks import load_task\n",
    "import numpy as np\n",
    "import random\n",
    "load_dotenv(override=True)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9a459a37-7446-4c4a-a7e0-38182b5dbd3e",
   "metadata": {},
   "source": [
    "Let's first define some support functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "1ccc3b21bf9ddc48",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-06-11T19:30:42.098338405Z",
     "start_time": "2024-06-11T19:30:42.093473103Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def set_seed(seed):\n",
    "    np.random.seed(seed)\n",
    "    random.seed(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "649e06aef34d0990",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def eval_sample(item, eval_fn, model):\n",
    "    \"\"\"\n",
    "    This function allows us to evaluate if an answer to a question in the prompt is a good answer.\n",
    "\n",
    "    \"\"\"\n",
    "    x, y = item\n",
    "    x = tg.Variable(x, requires_grad=False, role_description=\"query to the language model\")\n",
    "    y = tg.Variable(y, requires_grad=False, role_description=\"correct answer for the query\")\n",
    "    response = model(x)\n",
    "    try:\n",
    "        eval_output_variable = eval_fn(inputs=dict(prediction=response, ground_truth_answer=y))\n",
    "        return int(eval_output_variable.value)\n",
    "    except:\n",
    "        eval_output_variable = eval_fn([x, y, response])\n",
    "        eval_output_parsed = eval_fn.parse_output(eval_output_variable)\n",
    "        return int(eval_output_parsed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "9559a31e07e54d7f",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def eval_dataset(test_set, eval_fn, model, max_samples: int=None):\n",
    "    if max_samples is None:\n",
    "        max_samples = len(test_set)\n",
    "    accuracy_list = []\n",
    "    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:\n",
    "        futures = []\n",
    "        for _, sample in enumerate(test_set):\n",
    "            \n",
    "            future = executor.submit(eval_sample, sample, eval_fn, model)\n",
    "            futures.append(future)\n",
    "            if len(futures) >= max_samples:\n",
    "                break\n",
    "        tqdm_loader = tqdm(concurrent.futures.as_completed(futures), total=len(futures), position=0)\n",
    "        for future in tqdm_loader:\n",
    "            acc_item = future.result()\n",
    "            accuracy_list.append(acc_item)\n",
    "            tqdm_loader.set_description(f\"Accuracy: {np.mean(accuracy_list)}\")\n",
    "    return accuracy_list "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "4ea732b7edf34eb9",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def run_validation_revert(system_prompt: tg.Variable, results, model, eval_fn, val_set):\n",
    "    val_performance = np.mean(eval_dataset(val_set, eval_fn, model))\n",
    "    previous_performance = np.mean(results[\"validation_acc\"][-1])\n",
    "    print(\"val_performance: \", val_performance)\n",
    "    print(\"previous_performance: \", previous_performance)\n",
    "    previous_prompt = results[\"prompt\"][-1]\n",
    "    \n",
    "    if val_performance < previous_performance:\n",
    "        print(f\"rejected prompt: {system_prompt.value}\")\n",
    "        system_prompt.set_value(previous_prompt)\n",
    "        val_performance = previous_performance\n",
    "\n",
    "    results[\"validation_acc\"].append(val_performance)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "e69f8431-661c-42f8-b7fc-efccea588a03",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train/Val/Test Set Lengths:  83 83 84\n"
     ]
    }
   ],
   "source": [
    "set_seed(12)\n",
    "llm_api_eval = tg.get_engine(engine_name=\"gpt-4o\")\n",
    "llm_api_test = tg.get_engine(engine_name=\"gpt-3.5-turbo-0125\")\n",
    "tg.set_backward_engine(llm_api_eval, override=True)\n",
    "\n",
    "# Load the data and the evaluation function\n",
    "train_set, val_set, test_set, eval_fn = load_task(\"BBH_object_counting\", evaluation_api=llm_api_eval)\n",
    "print(\"Train/Val/Test Set Lengths: \", len(train_set), len(val_set), len(test_set))\n",
    "STARTING_SYSTEM_PROMPT = train_set.get_task_description()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f40b576c-4ba0-4e6e-b3ed-81eb44524676",
   "metadata": {},
   "source": [
    "This is the system prompt we are going to start from:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3ed3261-6f9d-4906-8c4b-a3ad570f5950",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(STARTING_SYSTEM_PROMPT)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "f7544127-38e0-4c74-8632-003efcc645ee",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.7142857142857143: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:00<00:00, 1438.83it/s]\n",
      "Accuracy: 0.6867469879518072: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 83/83 [00:00<00:00, 1318.42it/s]\n"
     ]
    }
   ],
   "source": [
    "train_loader = tg.tasks.DataLoader(train_set, batch_size=3, shuffle=True)\n",
    "\n",
    "\n",
    "# Testing the 0-shot performance of the evaluation engine\n",
    "system_prompt = tg.Variable(STARTING_SYSTEM_PROMPT, \n",
    "                            requires_grad=True, \n",
    "                            role_description=\"system prompt to the language model\")\n",
    "model_evaluation = tg.BlackboxLLM(llm_api_eval, system_prompt)\n",
    "\n",
    "system_prompt = tg.Variable(STARTING_SYSTEM_PROMPT, \n",
    "                            requires_grad=True,\n",
    "                            role_description=\"structured system prompt to a somewhat capable language model that specifies the behavior and strategies for the QA task\")\n",
    "model = tg.BlackboxLLM(llm_api_test, system_prompt)\n",
    "\n",
    "optimizer = tg.TextualGradientDescent(engine=llm_api_eval, parameters=[system_prompt])\n",
    "\n",
    "results = {\"test_acc\": [], \"prompt\": [], \"validation_acc\": []}\n",
    "results[\"test_acc\"].append(eval_dataset(test_set, eval_fn, model))\n",
    "results[\"validation_acc\"].append(eval_dataset(val_set, eval_fn, model))\n",
    "results[\"prompt\"].append(system_prompt.get_value())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "b3807736-1d81-4349-95db-257c20110d1a",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.5542168674698795: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 83/83 [00:21<00:00,  3.94it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "val_performance:  0.5542168674698795\n",
      "previous_performance:  0.6867469879518072\n",
      "rejected prompt: You will answer a reasoning question. Provide the final numerical answer directly. Your response should be a single numerical value. Do not include any explanation or additional text, only the numerical answer.\n",
      "sys prompt:  You will answer a reasoning question. Think step by step. The last line of your response should be of the following format: 'Answer: $VALUE' where VALUE is a numerical value.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.7142857142857143: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:00<00:00, 1201.24it/s]\n",
      "Accuracy: 0.8313253012048193: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 83/83 [00:43<00:00,  1.91it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "val_performance:  0.8313253012048193\n",
      "previous_performance:  0.6867469879518072\n",
      "sys prompt:  You will answer a reasoning question that involves counting items in a list. Think step by step, but provide a concise summary of your reasoning. Ensure you understand the context of the question and focus on counting the items accurately. List each item you count and then verify the total number. Avoid adding any extra information or context that is not directly related to the total count. The last line of your response should be of the following format: 'Answer: $VALUE' where VALUE is a numerical value. Ensure the final line contains only the answer in the required format.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.8214285714285714: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:43<00:00,  1.94it/s]\n",
      "Accuracy: 0.5542168674698795: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 83/83 [00:46<00:00,  1.80it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "val_performance:  0.5542168674698795\n",
      "previous_performance:  0.8313253012048193\n",
      "rejected prompt: You will answer a reasoning question that involves counting items in a list. Think step by step, but provide a concise summary of your reasoning. Ensure you understand the context of the question and focus on counting the items accurately. \n",
      "\n",
      "1. Restate the question to ensure clarity.\n",
      "2. List each item you count on a new line using bullet points for clarity. If there are multiple items of the same type, list them together and indicate the quantity.\n",
      "3. Ensure the numbering and items in the list are consistent and free from typographical errors.\n",
      "4. After listing all items, count the total number of items and provide the final count in the specified format.\n",
      "5. Verify the total number by cross-checking with the list.\n",
      "6. If there are potential errors or ambiguities in the list, acknowledge them and request additional details if necessary.\n",
      "\n",
      "The last line of your response should be of the following format: 'Answer: $VALUE' where VALUE is a numerical value. Ensure the final line contains only the answer in the required format.\n",
      "sys prompt:  You will answer a reasoning question that involves counting items in a list. Think step by step, but provide a concise summary of your reasoning. Ensure you understand the context of the question and focus on counting the items accurately. List each item you count and then verify the total number. Avoid adding any extra information or context that is not directly related to the total count. The last line of your response should be of the following format: 'Answer: $VALUE' where VALUE is a numerical value. Ensure the final line contains only the answer in the required format.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.8214285714285714: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:00<00:00, 1638.96it/s]\n",
      "Accuracy: 0.6867469879518072: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 83/83 [00:18<00:00,  4.43it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "val_performance:  0.6867469879518072\n",
      "previous_performance:  0.8313253012048193\n",
      "rejected prompt: You will answer a reasoning question that involves counting items in a list. Directly compute the total number of items and provide the final count. Ensure you understand the context of the question and focus on counting the items accurately. Avoid adding any extra information or context that is not directly related to the total count. Ensure your response follows the format 'Answer: $VALUE' with no additional text. If the input contains unexpected characters or is malformed, correct the input and provide a coherent response. Ensure the final line contains only the answer in the required format.\n",
      "sys prompt:  You will answer a reasoning question that involves counting items in a list. Think step by step, but provide a concise summary of your reasoning. Ensure you understand the context of the question and focus on counting the items accurately. List each item you count and then verify the total number. Avoid adding any extra information or context that is not directly related to the total count. The last line of your response should be of the following format: 'Answer: $VALUE' where VALUE is a numerical value. Ensure the final line contains only the answer in the required format.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.8214285714285714: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:00<00:00, 1043.66it/s]\n",
      "Training step 3. Epoch 0: : 3it [07:29, 149.70s/it]\n",
      "Accuracy: 0.4819277108433735: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 83/83 [00:34<00:00,  2.44it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "val_performance:  0.4819277108433735\n",
      "previous_performance:  0.8313253012048193\n",
      "rejected prompt: Answer a reasoning question that involves counting items in a list. Focus on counting only the items relevant to the question. Directly provide the total count without listing each item. Ensure the final line contains only the answer in the format 'Answer: $VALUE' where VALUE is a numerical value. Double-check your count for accuracy before providing the final number. If you encounter any ambiguous items or quantities, make a note of them and proceed with the calculation based on the clear items.\n",
      "sys prompt:  You will answer a reasoning question that involves counting items in a list. Think step by step, but provide a concise summary of your reasoning. Ensure you understand the context of the question and focus on counting the items accurately. List each item you count and then verify the total number. Avoid adding any extra information or context that is not directly related to the total count. The last line of your response should be of the following format: 'Answer: $VALUE' where VALUE is a numerical value. Ensure the final line contains only the answer in the required format.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.8214285714285714: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:00<00:00, 1660.49it/s]\n",
      "Accuracy: 0.3132530120481928: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 83/83 [00:16<00:00,  5.03it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "val_performance:  0.3132530120481928\n",
      "previous_performance:  0.8313253012048193\n",
      "rejected prompt: You will answer a reasoning question that involves counting items in a list. Ensure you understand the context of the question, such as identifying specific categories of items (e.g., vegetables) before counting. Count the items and provide the total number. Be cautious of items that might be similar or easily confused, and ensure you are counting the correct items based on the context. If the query is ambiguous or could be interpreted in multiple ways, provide a brief explanation of your reasoning or ask for clarification. Ensure the response is a single numerical value without any additional text or formatting.\n",
      "sys prompt:  You will answer a reasoning question that involves counting items in a list. Think step by step, but provide a concise summary of your reasoning. Ensure you understand the context of the question and focus on counting the items accurately. List each item you count and then verify the total number. Avoid adding any extra information or context that is not directly related to the total count. The last line of your response should be of the following format: 'Answer: $VALUE' where VALUE is a numerical value. Ensure the final line contains only the answer in the required format.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.8214285714285714: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:00<00:00, 1674.21it/s]\n",
      "Accuracy: 0.891566265060241: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 83/83 [01:00<00:00,  1.36it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "val_performance:  0.891566265060241\n",
      "previous_performance:  0.8313253012048193\n",
      "sys prompt:  You will answer a reasoning question that involves counting items in a list. Think step by step, but provide a concise summary of your reasoning. Ensure you understand the context of each item and its relevance to the total count. Use bullet points or numbering with periods for listing items. Maintain consistent naming conventions for similar items (e.g., Bed 1, Bed 2). If there is any ambiguity, provide reasoning for your choice or ask a clarifying question. After listing the items, verify the total count to ensure accuracy. The last line of your response should be of the following format: 'Answer: $VALUE' where VALUE is a numerical value. Ensure the final line contains only the answer in the required format.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.8571428571428571: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [01:00<00:00,  1.38it/s]\n",
      "Accuracy: 0.8313253012048193: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 83/83 [00:43<00:00,  1.89it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "val_performance:  0.8313253012048193\n",
      "previous_performance:  0.891566265060241\n",
      "rejected prompt: You will answer a reasoning question that involves counting items in a list. Think step by step, but provide a concise summary of your reasoning. Ensure you understand the context of each item and its relevance to the total count. Use bullet points with a hyphen (-) for each item. Maintain consistent naming conventions for similar items (e.g., Bed 1, Bed 2). If there is any ambiguity, provide reasoning for your choice or ask a clarifying question. After listing the items, state the total number of items in the format: 'The total number of items listed is X.' Ensure the final line contains only the answer in the required format: 'Answer: $VALUE' where VALUE is a numerical value. Double-check your arithmetic to ensure the sum of the counts is correct.\n",
      "sys prompt:  You will answer a reasoning question that involves counting items in a list. Think step by step, but provide a concise summary of your reasoning. Ensure you understand the context of each item and its relevance to the total count. Use bullet points or numbering with periods for listing items. Maintain consistent naming conventions for similar items (e.g., Bed 1, Bed 2). If there is any ambiguity, provide reasoning for your choice or ask a clarifying question. After listing the items, verify the total count to ensure accuracy. The last line of your response should be of the following format: 'Answer: $VALUE' where VALUE is a numerical value. Ensure the final line contains only the answer in the required format.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.8571428571428571: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:00<00:00, 1432.86it/s]\n",
      "Training step 3. Epoch 1: : 3it [08:16, 165.36s/it]\n",
      "Accuracy: 0.6144578313253012: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 83/83 [00:18<00:00,  4.50it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "val_performance:  0.6144578313253012\n",
      "previous_performance:  0.891566265060241\n",
      "rejected prompt: You will answer a reasoning question that involves counting items in a list. Provide only the final count of items without listing intermediate steps. Ensure you understand the context of each item and its relevance to the total count. If there is any ambiguity, ask a clarifying question. Present the numerical value clearly and directly, without any surrounding text or context. Ensure the final answer is a single numerical value without any additional text. Do not repeat the final answer; provide it only once in the specified format. Avoid using ellipses, parentheses, or any other punctuation that could complicate the response. The last line of your response should be of the following format: 'Answer: $VALUE' where VALUE is a numerical value. Ensure the final line contains only the answer in the required format.\n",
      "sys prompt:  You will answer a reasoning question that involves counting items in a list. Think step by step, but provide a concise summary of your reasoning. Ensure you understand the context of each item and its relevance to the total count. Use bullet points or numbering with periods for listing items. Maintain consistent naming conventions for similar items (e.g., Bed 1, Bed 2). If there is any ambiguity, provide reasoning for your choice or ask a clarifying question. After listing the items, verify the total count to ensure accuracy. The last line of your response should be of the following format: 'Answer: $VALUE' where VALUE is a numerical value. Ensure the final line contains only the answer in the required format.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.8571428571428571: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:00<00:00, 1608.54it/s]\n",
      "Accuracy: 0.5301204819277109: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 83/83 [00:16<00:00,  5.10it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "val_performance:  0.5301204819277109\n",
      "previous_performance:  0.891566265060241\n",
      "rejected prompt: Count the items and provide the total number as a single integer. Do not include any additional text, explanations, or lists in your response. Ensure the answer is a numerical value only.\n",
      "sys prompt:  You will answer a reasoning question that involves counting items in a list. Think step by step, but provide a concise summary of your reasoning. Ensure you understand the context of each item and its relevance to the total count. Use bullet points or numbering with periods for listing items. Maintain consistent naming conventions for similar items (e.g., Bed 1, Bed 2). If there is any ambiguity, provide reasoning for your choice or ask a clarifying question. After listing the items, verify the total count to ensure accuracy. The last line of your response should be of the following format: 'Answer: $VALUE' where VALUE is a numerical value. Ensure the final line contains only the answer in the required format.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.8571428571428571: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:00<00:00, 1517.28it/s]\n",
      "Accuracy: 0.8433734939759037: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 83/83 [01:20<00:00,  1.03it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "val_performance:  0.8433734939759037\n",
      "previous_performance:  0.891566265060241\n",
      "rejected prompt: You will answer a reasoning question that involves counting items in a list. Think step by step, but provide a concise summary of your reasoning. Ensure you understand the context of each item and its relevance to the total count. Use bullet points or numbering with periods for listing items. Maintain consistent naming conventions for similar items (e.g., Bed 1, Bed 2). Ensure that the numbering is sequential and does not skip any numbers. Avoid using ellipses or incomplete names; list each item explicitly. If there is any ambiguity, provide reasoning for your choice or ask a clarifying question. Ensure that each item listed is distinct and not a duplicate. After listing the items, count them explicitly and verify that the total matches the number of distinct items listed. Explicitly show all mathematical operations in your reasoning. For example, if you are adding quantities, write out the full addition process. Before providing the final answer, double-check that the total count is accurate and matches the items listed. The last line of your response should be of the following format: 'Answer: $VALUE' where VALUE is a numerical value. Ensure the final line contains only the answer in the required format.\n",
      "sys prompt:  You will answer a reasoning question that involves counting items in a list. Think step by step, but provide a concise summary of your reasoning. Ensure you understand the context of each item and its relevance to the total count. Use bullet points or numbering with periods for listing items. Maintain consistent naming conventions for similar items (e.g., Bed 1, Bed 2). If there is any ambiguity, provide reasoning for your choice or ask a clarifying question. After listing the items, verify the total count to ensure accuracy. The last line of your response should be of the following format: 'Answer: $VALUE' where VALUE is a numerical value. Ensure the final line contains only the answer in the required format.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.8571428571428571: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:00<00:00, 1705.01it/s]\n",
      "Accuracy: 0.7710843373493976: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 83/83 [00:27<00:00,  3.07it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "val_performance:  0.7710843373493976\n",
      "previous_performance:  0.891566265060241\n",
      "rejected prompt: You will answer a reasoning question that involves counting items in a list. Provide a concise answer without detailed explanations. List each item clearly and sequentially without using bullet points or numbering. Use consistent naming conventions for similar items (e.g., Bed 1, Bed 2). If there is any ambiguity, make a reasonable assumption and provide the answer directly. Ensure the final line contains only the answer in the following format: 'Answer: $VALUE' where VALUE is a numerical value.\n",
      "sys prompt:  You will answer a reasoning question that involves counting items in a list. Think step by step, but provide a concise summary of your reasoning. Ensure you understand the context of each item and its relevance to the total count. Use bullet points or numbering with periods for listing items. Maintain consistent naming conventions for similar items (e.g., Bed 1, Bed 2). If there is any ambiguity, provide reasoning for your choice or ask a clarifying question. After listing the items, verify the total count to ensure accuracy. The last line of your response should be of the following format: 'Answer: $VALUE' where VALUE is a numerical value. Ensure the final line contains only the answer in the required format.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.8571428571428571: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:00<00:00, 1062.32it/s]\n",
      "Training step 3. Epoch 2: : 3it [07:08, 142.76s/it]\n"
     ]
    }
   ],
   "source": [
    "for epoch in range(3):\n",
    "    for steps, (batch_x, batch_y) in enumerate((pbar := tqdm(train_loader, position=0))):\n",
    "        pbar.set_description(f\"Training step {steps}. Epoch {epoch}\")\n",
    "        optimizer.zero_grad()\n",
    "        losses = []\n",
    "        for (x, y) in zip(batch_x, batch_y):\n",
    "            x = tg.Variable(x, requires_grad=False, role_description=\"query to the language model\")\n",
    "            y = tg.Variable(y, requires_grad=False, role_description=\"correct answer for the query\")\n",
    "            response = model(x)\n",
    "            try:\n",
    "                eval_output_variable = eval_fn(inputs=dict(prediction=response, ground_truth_answer=y))\n",
    "            except:\n",
    "                eval_output_variable = eval_fn([x, y, response])\n",
    "            losses.append(eval_output_variable)\n",
    "        total_loss = tg.sum(losses)\n",
    "        total_loss.backward()\n",
    "        optimizer.step()\n",
    "        \n",
    "        run_validation_revert(system_prompt, results, model, eval_fn, val_set)\n",
    "        \n",
    "        print(\"sys prompt: \", system_prompt)\n",
    "        test_acc = eval_dataset(test_set, eval_fn, model)\n",
    "        results[\"test_acc\"].append(test_acc)\n",
    "        results[\"prompt\"].append(system_prompt.get_value())\n",
    "        if steps == 3:\n",
    "            break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7dab7a53-e682-478e-9417-15009b495979",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
